# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Installation script."""

import os
import subprocess
import sys
import sysconfig
import copy
import time

from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Type

import setuptools

from .utils import (
    cmake_bin,
    debug_build_enabled,
    found_ninja,
    get_frameworks,
    nvcc_path,
    get_max_jobs_for_parallel_build,
)


class CMakeExtension(setuptools.Extension):
    """CMake extension module.

    The base ``setuptools.Extension`` is created with no sources; the actual
    build is performed by :meth:`_build_cmake`, which runs CMake's configure,
    build and install steps.
    """

    def __init__(
        self,
        name: str,
        cmake_path: Path,
        cmake_flags: Optional[List[str]] = None,
    ) -> None:
        """Create a CMake-built extension.

        Args:
            name: Extension name forwarded to ``setuptools.Extension``.
            cmake_path: CMake source directory (passed as ``cmake -S``).
            cmake_flags: Extra arguments appended to the CMake configure command.
        """
        super().__init__(name, sources=[])  # No work for base class
        self.cmake_path: Path = cmake_path
        self.cmake_flags: List[str] = [] if cmake_flags is None else cmake_flags

    def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
        """Configure, build and install this extension with CMake.

        Args:
            build_dir: CMake binary directory (``-B``); also used as the working
                directory for every CMake invocation.
            install_dir: Value for ``CMAKE_INSTALL_PREFIX``.

        Raises:
            RuntimeError: If any CMake command exits non-zero or cannot be
                launched. The underlying exception is chained as ``__cause__``.
        """
        # Make sure paths are str
        _cmake_bin = str(cmake_bin())
        cmake_path = str(self.cmake_path)
        build_dir = str(build_dir)
        install_dir = str(install_dir)

        # CMake configure command
        build_type = "Debug" if debug_build_enabled() else "Release"
        configure_command = [
            _cmake_bin,
            "-S",
            cmake_path,
            "-B",
            build_dir,
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
            f"-DPython_SITEARCH={sysconfig.get_path('platlib')}",
            f"-DCMAKE_BUILD_TYPE={build_type}",
            f"-DCMAKE_INSTALL_PREFIX={install_dir}",
        ]
        configure_command += self.cmake_flags

        # Imported lazily so pybind11 is only required when a CMake build runs.
        import pybind11

        pybind11_dir = Path(pybind11.__file__).resolve().parent
        pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
        configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")

        # CMake build and install commands
        build_command = [_cmake_bin, "--build", build_dir, "--verbose"]
        install_command = [_cmake_bin, "--install", build_dir, "--verbose"]

        # Check whether parallel build is restricted
        max_jobs = get_max_jobs_for_parallel_build()
        if found_ninja():
            configure_command.append("-GNinja")
        # `--parallel` without a job count lets the native build tool use its
        # default; a positive max_jobs caps it explicitly.
        build_command.append("--parallel")
        if max_jobs > 0:
            build_command.append(str(max_jobs))

        # Run CMake commands
        start_time = time.perf_counter()
        for command in [configure_command, build_command, install_command]:
            print(f"Running command {' '.join(command)}")
            try:
                subprocess.run(command, cwd=build_dir, check=True)
            except (CalledProcessError, OSError) as e:
                # Chain the cause so the failing command / OS error stays visible.
                raise RuntimeError(f"Error when running CMake: {e}") from e

        total_time = time.perf_counter() - start_time
        print(f"Time for build_ext: {total_time:.2f} seconds")


def get_build_ext(
    extension_cls: Type[setuptools.Command], framework_extension_only: bool = False
) -> Type[setuptools.Command]:
    """Create a ``build_ext`` command class that also builds CMake extensions.

    Args:
        extension_cls: ``build_ext``-style setuptools command class to subclass.
            It must provide ``run``, ``build_extensions`` and ``get_ext_fullpath``.
        framework_extension_only: When True, ``.cu``/``.cuh`` sources are not
            registered with or routed to NVCC, and (for non-inplace builds)
            shared objects are relocated under ``transformer_engine/wheel_lib``.

    Returns:
        A subclass of ``extension_cls`` that builds every ``CMakeExtension``
        through CMake and delegates all other extensions to ``extension_cls``.
    """

    class _CMakeBuildExtension(extension_cls):
        """Setuptools command with support for CMake extension modules"""

        def run(self) -> None:
            # Output directory of the last extension seen below; used to relocate
            # shared objects. Stays None when there are no extensions at all.
            install_dir: Optional[Path] = None

            # Build CMake extensions
            for ext in self.extensions:
                package_path = Path(self.get_ext_fullpath(ext.name))
                install_dir = package_path.resolve().parent
                if isinstance(ext, CMakeExtension):
                    print(f"Building CMake extension {ext.name}")
                    # Set up incremental builds for CMake extensions
                    build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
                    if build_dir:
                        build_dir = Path(build_dir).resolve()
                    else:
                        root_dir = Path(__file__).resolve().parent.parent
                        build_dir = root_dir / "build" / "cmake"

                    # Ensure the directory exists
                    build_dir.mkdir(parents=True, exist_ok=True)

                    ext._build_cmake(
                        build_dir=build_dir,
                        install_dir=install_dir,
                    )

            # Build non-CMake extensions as usual; restore the full list afterwards.
            all_extensions = self.extensions
            self.extensions = [
                ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
            ]
            super().run()
            self.extensions = all_extensions

            # Ensure that shared objects files for source and PyPI installations live
            # in separate directories to avoid conflicts during install and runtime.
            lib_dir = (
                "wheel_lib"
                if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
                else ""
            )

            # Ensure that binaries are not in global package space.
            # For editable/inplace builds this is not a concern as
            # the SOs will be in a local directory anyway.
            # With no extensions there is no install_dir and nothing to relocate.
            if not self.inplace and install_dir is not None:
                target_dir = install_dir / "transformer_engine" / lib_dir
                target_dir.mkdir(exist_ok=True, parents=True)

                # Materialize the listing before deleting files from that directory.
                for so_path in list(Path(self.build_lib).glob("*.so")):
                    self.copy_file(so_path, target_dir)
                    # copy_file honours --dry-run; only delete when it really copied.
                    if not self.dry_run:
                        os.remove(so_path)

        def build_extensions(self) -> None:
            # For core lib + JAX install, fix build_ext from pybind11.setup_helpers
            # to handle CUDA files correctly.
            if "pytorch" not in get_frameworks():
                # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
                # extra_compile_args is a dict.
                for ext in self.extensions:
                    if isinstance(ext.extra_compile_args, dict):
                        for target in ("cxx", "nvcc"):
                            ext.extra_compile_args.setdefault(target, [])

                # Define new _compile method that redirects to NVCC for .cu and .cuh files.
                original_compile_fn = self.compiler._compile
                if not framework_extension_only:
                    # Rebind rather than `+=`: `src_extensions` is a class-level list on
                    # distutils compilers, so in-place extension would leak into every
                    # compiler instance and accumulate duplicates across calls.
                    current_exts = self.compiler.src_extensions
                    self.compiler.src_extensions = list(current_exts) + [
                        suffix for suffix in (".cu", ".cuh") if suffix not in current_exts
                    ]

                def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
                    # Copy before we make any modifications.
                    cflags = copy.deepcopy(extra_postargs)
                    original_compiler = self.compiler.compiler_so
                    try:
                        if (
                            os.path.splitext(src)[1] in [".cu", ".cuh"]
                            and not framework_extension_only
                        ):
                            nvcc_bin = nvcc_path()
                            self.compiler.set_executable("compiler_so", str(nvcc_bin))
                            if isinstance(cflags, dict):
                                cflags = cflags["nvcc"]

                            # Add -fPIC if not already specified
                            if not any("-fPIC" in flag for flag in cflags):
                                cflags.extend(["--compiler-options", "'-fPIC'"])

                            # Forward unknown options
                            if not any("--forward-unknown-opts" in flag for flag in cflags):
                                cflags.append("--forward-unknown-opts")
                        elif isinstance(cflags, dict):
                            cflags = cflags["cxx"]

                        # Append -std=c++17 if not already in flags
                        if not any(flag.startswith("-std=") for flag in cflags):
                            cflags.append("-std=c++17")

                        return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)

                    finally:
                        # Put the original compiler back in place.
                        self.compiler.set_executable("compiler_so", original_compiler)

                self.compiler._compile = _compile_fn

            super().build_extensions()

    return _CMakeBuildExtension
